{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Y11F4x7U4skJ"
   },
   "outputs": [],
   "source": [
    "# Copyright 2022 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CrkAxQOUkkJA"
   },
   "source": [
    "# First, some setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "id": "kUX6_56tDC1A"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-08 16:48:20.742581: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import functools\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import jaxline\n",
    "from typing import Generator, Mapping, Sequence, Text\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import time\n",
    "import haiku as hk\n",
    "import IPython\n",
    "from PIL import Image\n",
    "import datetime\n",
    "\n",
    "from perceiver_ar import experiment\n",
    "from perceiver_ar import perceiver_ar_model\n",
    "from perceiver_ar import dataset\n",
    "from perceiver_ar import sample_utils\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
       " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
       " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
       " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
       " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
       " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
       " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
       " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the accelerator devices that JAX can see from this process.\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CO_GI8OoCz43"
   },
   "source": [
    "# Restore a full experiment from a checkpoint\n",
    "\n",
    "Uncomment the section related to the experiment you want to load."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6z9FQg4y4skP",
    "outputId": "edf4121b-b415-4577-f397-b08b84bdc4a1"
   },
   "outputs": [],
   "source": [
    "# Choose the experiment to restore: keep exactly one section below active and\n",
    "# leave the others commented out. Each section sets the four names read by\n",
    "# later cells: modality, input_sequence_init, sweep_name and checkpoint_base.\n",
    "\n",
    "## Synthetic Copy Task, 32 positions (suitable for local CPU training)\n",
    "# modality = 'raw'\n",
    "# input_sequence_init = 'mirror_input'\n",
    "# sweep_name = 'random_mirrored_32'\n",
    "# checkpoint_base = Path('/tmp/perceiver_ar')\n",
    "\n",
    "## Synthetic Copy Task, 131k positions\n",
    "# Shell commands to fetch the checkpoint. checkpoint_base below resolves under\n",
    "# the home directory, so run them from there (or adjust checkpoint_base).\n",
    "# $ mkdir perceiver-ar-checkpoints\n",
    "# $ gsutil cp gs://perceiver-ar/checkpoints/random_mirrored_131072.zip perceiver-ar-checkpoints\n",
    "# $ unzip perceiver-ar-checkpoints/random_mirrored_131072.zip -d perceiver-ar-checkpoints\n",
    "modality = 'raw'\n",
    "input_sequence_init = 'mirror_input'\n",
    "sweep_name = 'random_mirrored_131072'\n",
    "checkpoint_base = Path.home() / 'perceiver-ar-checkpoints/random_mirrored_131072'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DC_0lmzQiefa",
    "outputId": "bbdf9089-b11c-40ed-fe4a-6bc7cdf7836e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model will be loaded from: /home/fjord/perceiver-ar-checkpoints/random_mirrored_131072/models/latest/step_25000\n"
     ]
    }
   ],
   "source": [
    "# Pick the most recently modified entry under models/latest (the saved output\n",
    "# of this cell shows entries named like `step_25000`).\n",
    "checkpoint_root = checkpoint_base / 'models/latest'\n",
    "checkpoint_candidates = list(checkpoint_root.iterdir())\n",
    "if not checkpoint_candidates:\n",
    "  # Fail here with a clear message instead of a bare IndexError.\n",
    "  raise FileNotFoundError(f'No checkpoints found under {checkpoint_root}.')\n",
    "checkpoint_dir = max(checkpoint_candidates, key=lambda p: p.stat().st_mtime)\n",
    "\n",
    "print(f'Model will be loaded from: {checkpoint_dir}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "pp_6H5t94skQ",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the experiment config for the sweep selected above.\n",
    "config = experiment.get_config(sweep_name)\n",
    "# NOTE(review): going by the helper's name, this loads the on-disk checkpoint\n",
    "# into jaxline's in-memory checkpointer so the next cell can restore 'latest'\n",
    "# from it -- confirm against perceiver_ar.experiment.\n",
    "experiment.restore_state_to_in_memory_checkpointer(checkpoint_dir, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DYI4md87D4l2",
    "outputId": "52046667-ed78-46b6-dd72-b0b0ffb62c30"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restored step 25000\n"
     ]
    }
   ],
   "source": [
    "checkpointer = jaxline.platform.create_checkpointer(config, 'eval')\n",
    "state = checkpointer.get_experiment_state('latest')\n",
    "\n",
    "# Add the fields you want to restore here.\n",
    "# Must include experiment_module.\n",
    "# global_step is only a placeholder: the value printed below is the one that\n",
    "# restore() reads back from the checkpoint.\n",
    "state.global_step = 0\n",
    "state.experiment_module = experiment.Experiment(\n",
    "    'eval', jax.random.PRNGKey(config.random_seed),\n",
    "    **config.experiment_kwargs)\n",
    "\n",
    "checkpointer.restore('latest')\n",
    "# Keep a single (first-device) copy of the restored params and state; the\n",
    "# sampling function closes over these.\n",
    "exp_params = jaxline.utils.get_first(state.experiment_module._params)\n",
    "exp_state = jaxline.utils.get_first(state.experiment_module._state)\n",
    "\n",
    "# The event buffer and its length are built in the input-sequence cell, so\n",
    "# nothing is allocated here.\n",
    "print('Restored step', state.global_step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_w9G_ZD-kHQS"
   },
   "source": [
    "# Now run the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "a14UJG7q6Qif",
    "outputId": "7321aa65-4bf6-4794-97c6-7b6095b72a5b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device count 8\n",
      "Using input_sequence_init `mirror_input`. Setting start_step to 65536.\n"
     ]
    }
   ],
   "source": [
    "#@title Set up the input sequence (NB: start_step ignored for some init types)\n",
    "\n",
    "batch_size =   1#@param {type:\"integer\"}\n",
    "start_step = 1#@param {type:\"integer\"}\n",
    "\n",
    "device_count = jax.local_device_count()\n",
    "print('device count', device_count)\n",
    "\n",
    "max_context_length = config.experiment_kwargs.config.data.max_context_length\n",
    "# We want to store max_context_length plus the final prediction.\n",
    "max_events_length = max_context_length + 1\n",
    "\n",
    "if input_sequence_init == 'zeros':\n",
    "  def gen_initial_events():\n",
    "    # Returns an all-zero [device, batch, position] int32 buffer whose first\n",
    "    # position holds the SOS prompt.\n",
    "    events = np.zeros([device_count, batch_size, max_events_length], np.int32)\n",
    "    # Add expected SOS prompt.\n",
    "    events[:, :, 0] = dataset.SOS_ID\n",
    "    return events\n",
    "elif input_sequence_init == 'mirror_input':\n",
    "  # Reserve two positions (the leading SOS plus one more; the saved run shows\n",
    "  # an EOS there) and split the rest into a forward and a mirrored half.\n",
    "  seq_len = (max_context_length - 2) // 2\n",
    "\n",
    "  # Force start_step to the first position after the forward half; the form\n",
    "  # value above is ignored for this init type.\n",
    "  start_step = seq_len + 1\n",
    "  print('Using input_sequence_init `mirror_input`. Setting '\n",
    "        f'start_step to {start_step}.')\n",
    "\n",
    "  def gen_initial_events():\n",
    "    # Initialize with a random MirroredDataset sequence.\n",
    "    events = np.zeros([device_count, batch_size, max_events_length], np.int32)\n",
    "    # Fixed key: the prompt is identical on every run and does not depend on\n",
    "    # the random_seed used for sampling.\n",
    "    rng = jax.random.PRNGKey(0)\n",
    "    forward_sequence = jax.random.randint(\n",
    "        rng, [device_count, batch_size, seq_len],\n",
    "        minval=dataset.NUM_RESERVED_TOKENS,\n",
    "        maxval=256 + dataset.NUM_RESERVED_TOKENS,\n",
    "        dtype=jnp.int32)\n",
    "\n",
    "    # Positions 1..seq_len hold the forward half that the model should mirror.\n",
    "    events[:, :, 1:seq_len+1] = forward_sequence\n",
    "    # Add expected SOS prompt.\n",
    "    events[:, :, 0] = dataset.SOS_ID\n",
    "    return events\n",
    "else:\n",
    "  # Without this, gen_initial_events would be left undefined and the sampling\n",
    "  # cell would fail later with a NameError.\n",
    "  raise ValueError(f'Unknown input_sequence_init: {input_sequence_init}')\n",
    "\n",
    "if start_step < 1:\n",
    "  raise ValueError('start_step must be >= 1 to account for the SOS token.')\n",
    "\n",
    "# Make sure start_step doesn't exceed the maximum context of the model.\n",
    "if start_step > max_context_length:\n",
    "  print(f'Warning: start_step {start_step} exceeds '\n",
    "        f'max_context_length used at training {max_context_length}.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1XO4kwRiX3Uy",
    "outputId": "8afce3d4-cafd-4a75-d3dc-892337b28955"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using default max_context_length for memory: 131072\n",
      "Using default z_index_dim for memory: 1024\n"
     ]
    }
   ],
   "source": [
    "#@title Memory parameters (set to 0 to use model defaults)\n",
    "use_memory = True #@param {type:\"boolean\"}\n",
    "max_context_length_memory = 0 #@param {type:\"integer\"}\n",
    "z_index_dim_memory =  0#@param {type: \"integer\"}\n",
    "\n",
    "model_kwargs = config.experiment_kwargs.config.model.perceiver_ar_kwargs\n",
    "\n",
    "# A form value of 0 means \"fall back to the value stored in the config\".\n",
    "# Nothing is resolved when memory is disabled.\n",
    "if use_memory and max_context_length_memory == 0:\n",
    "  max_context_length_memory = config.max_context_length\n",
    "  print('Using default max_context_length for memory: '\n",
    "        f'{max_context_length_memory}')\n",
    "if use_memory and z_index_dim_memory == 0:\n",
    "  z_index_dim_memory = model_kwargs.z_index_dim\n",
    "  print(f'Using default z_index_dim for memory: {z_index_dim_memory}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "cellView": "form",
    "id": "06cUc_uRL903"
   },
   "outputs": [],
   "source": [
    "#@title Set up sampling\n",
    "\n",
    "@functools.partial(jax.jit, static_argnums=(2, 3, 4, 5))\n",
    "def sample_sequences(\n",
    "    events, rng,\n",
    "    start_step=1,\n",
    "    num_steps=-1,\n",
    "    temperature=1.0,\n",
    "    top_p=1.0):\n",
    "  \"\"\"Samples positions [start_step, upper) of `events` on one device.\n",
    "\n",
    "  The body runs under jit/pmap, so the `print` calls and every notebook-level\n",
    "  name read here (modality, config, use_memory, batch_size, the memory sizes,\n",
    "  state, exp_params, exp_state) are only evaluated when the function is\n",
    "  traced. Re-run this cell after changing any of them.\n",
    "\n",
    "  Args:\n",
    "    events: int32 [batch, position] buffer for this device; position 0 holds\n",
    "      the SOS token.\n",
    "    rng: PRNG key for this device.\n",
    "    start_step: First position to sample (static).\n",
    "    num_steps: Number of positions to sample; a negative value fills the rest\n",
    "      of the buffer (static).\n",
    "    temperature: Forwarded to sample_utils.sample_position (static).\n",
    "    top_p: Forwarded to sample_utils.sample_position (static).\n",
    "\n",
    "  Returns:\n",
    "    (events, rng, memory) when use_memory is set, otherwise (events, rng).\n",
    "  \"\"\"\n",
    "  if num_steps < 0:\n",
    "    num_steps = events.shape[1] - start_step + 1\n",
    "\n",
    "  # Exclusive upper bound of the sampled range, clamped to the buffer length.\n",
    "  upper = min(num_steps + start_step, events.shape[1])\n",
    "  print('start_step', start_step)\n",
    "  print('upper', upper)\n",
    "\n",
    "  if modality == 'image_w_positions':\n",
    "    # Per-event (x, y, channel) indices for a 64x64x3 image, flattened in the\n",
    "    # same row-major order as the event sequence.\n",
    "    x_event_idxs = jnp.reshape(\n",
    "        jnp.broadcast_to(jnp.arange(64)[None, :, None] + 1, [64, 64, 3]), [-1])\n",
    "    y_event_idxs = jnp.reshape(\n",
    "        jnp.broadcast_to(jnp.arange(64)[:, None, None] + 1, [64, 64, 3]), [-1])\n",
    "    channel_event_idxs = jnp.reshape(\n",
    "        jnp.broadcast_to(jnp.array([1, 2, 3]), [64, 64, 3]), [-1])\n",
    "\n",
    "    event_idxs = jnp.stack(\n",
    "        [x_event_idxs, y_event_idxs, channel_event_idxs], axis=1)\n",
    "\n",
    "    # Account for SOS.\n",
    "    event_idxs = jnp.concatenate(\n",
    "        [jnp.ones([1, 3], dtype=jnp.int32), event_idxs + 1], axis=0)\n",
    "\n",
    "    # Pad remaining positions.\n",
    "    event_idxs = jnp.pad(\n",
    "        event_idxs, [[0, events.shape[1] - event_idxs.shape[0]], [0, 0]])\n",
    "\n",
    "    event_idxs = jnp.broadcast_to(event_idxs, events.shape + (3,))\n",
    "  else:\n",
    "    # Otherwise, assume linear event indices.\n",
    "    event_idxs = jnp.arange(start=1, stop=events.shape[1] + 1)\n",
    "    event_idxs = jnp.expand_dims(event_idxs, axis=-1)\n",
    "    event_idxs = jnp.broadcast_to(event_idxs, events.shape + (1,))\n",
    "\n",
    "  model_kwargs = config.experiment_kwargs.config.model.perceiver_ar_kwargs\n",
    "\n",
    "  if use_memory:\n",
    "    # Zero-initialize the memory.\n",
    "    memory = sample_utils.initialize_memory(\n",
    "        batch_size=batch_size,\n",
    "        num_transformers_per_block=model_kwargs.num_transformers_per_block,\n",
    "        num_cross_attend_heads=model_kwargs.num_cross_attend_heads,\n",
    "        num_transformer_heads=model_kwargs.num_transformer_heads,\n",
    "        num_z_channels=model_kwargs.num_z_channels,\n",
    "        max_context_length_memory=max_context_length_memory,\n",
    "        z_index_dim_memory=z_index_dim_memory,\n",
    "        position_encoding_type=model_kwargs.position_encoding_type,\n",
    "        memory_type='fixed_size_kv')\n",
    "\n",
    "  # Build the parameters reused between model calls.\n",
    "  sample_position_args = dict(\n",
    "      event_idxs=event_idxs,\n",
    "      modality=modality,\n",
    "      temperature=temperature,\n",
    "      top_p=top_p,\n",
    "      use_memory=use_memory,\n",
    "  )\n",
    "\n",
    "  # Package the (constant) params and state with the model.\n",
    "  forward_fn = functools.partial(\n",
    "      state.experiment_module.forward.apply,\n",
    "      exp_params,\n",
    "      exp_state,\n",
    "  )\n",
    "\n",
    "  # start_step and upper are static Python ints, so this is decided at trace\n",
    "  # time. The `start_step < upper` guard skips the priming step when there is\n",
    "  # nothing to sample (num_steps == 0 or start_step beyond the buffer).\n",
    "  if use_memory and 1 < start_step < upper:\n",
    "    # Run forward the model with long context for one step to initialize\n",
    "    # the memory and get the first sample.\n",
    "    events, rng, memory = sample_utils.sample_position(\n",
    "        i=start_step,\n",
    "        events_rng_memory=(events, rng, memory),\n",
    "        forward_fn=forward_fn,\n",
    "        condition_on='all_previous',\n",
    "        **sample_position_args)\n",
    "    start_step += 1\n",
    "\n",
    "  if use_memory:\n",
    "    condition_on_loop = 'most_recent_only'\n",
    "  else:\n",
    "    condition_on_loop = 'all'\n",
    "  sample_positions_loop = functools.partial(\n",
    "      sample_utils.sample_position,\n",
    "      # i, events_rng_memory supplied by caller.\n",
    "      forward_fn=forward_fn,\n",
    "      condition_on=condition_on_loop,\n",
    "      **sample_position_args)\n",
    "\n",
    "  if use_memory:\n",
    "    inputs = (events, rng, memory)\n",
    "  else:\n",
    "    inputs = (events, rng)\n",
    "\n",
    "  # The loop carry keeps the structure of `inputs`, which is what the caller\n",
    "  # unpacks.\n",
    "  outputs = jax.lax.fori_loop(\n",
    "      start_step, upper, sample_positions_loop, inputs)\n",
    "\n",
    "  return outputs\n",
    "\n",
    "# Map over the leading device axis of (events, rng); the remaining arguments\n",
    "# are static and shared by all devices.\n",
    "sample_sequences = jax.pmap(\n",
    "    sample_sequences,\n",
    "    static_broadcasted_argnums=(2, 3, 4, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "8tVZgwjfC3Hp",
    "outputId": "98ce9b50-e039-4c74-f636-1c679301770c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "starting generation 1654707153.3989837\n",
      "start_step 65536\n",
      "upper 131073\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tcmalloc: large alloc 1793105920 bytes == 0x56531712a000 @  0x7fbc47d09680 0x7fbc47d2a824 0x56520359f5f3 0x5652035b9db8 0x5652035df86c 0x5652035ee6f1 0x5652036940b5 0x56520369410c 0x56520355e59c 0x56520362a826 0x5652035e4233 0x5652035b5b3e 0x7fbb70ade2ab 0x7fbb70ade40e 0x7fbb70317b27 0x7fbb704f2194 0x7fbb7040fc79 0x7fbb70416197 0x7fbb70416a5b 0x7fbb7041721b 0x7fbb704785e6 0x7fbb70328fab 0x7fbb703292a6 0x7fbb70416197 0x7fbb70416a5b 0x7fbb7041721b 0x7fbb7044d96d 0x7fbb7044d9c6 0x7fbb70416197 0x7fbb704175ba 0x7fbb7041c2cb\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generation complete 1654710158.1792338\n",
      "3004.7802500724792 seconds\n",
      "50.07967083454132 minutes\n"
     ]
    }
   ],
   "source": [
    "#@title Sample the sequence\n",
    "\n",
    "#@markdown `num_steps`=-1 will fill the entire buffer.\n",
    "num_steps =   -1#@param {type:\"integer\"}\n",
    "temperature =   1.0#@param {type:\"number\"}\n",
    "top_p =   1.0#@param {type:\"number\"}\n",
    "random_seed = 0#@param{type:\"integer\"}\n",
    "\n",
    "rng = jax.random.PRNGKey(random_seed)\n",
    "# One key per local device, matching the leading axis that pmap maps over.\n",
    "rng = jax.random.split(rng, device_count)\n",
    "\n",
    "events = gen_initial_events()\n",
    "\n",
    "# Drop references to results of a previous run before starting a new one\n",
    "# (presumably so that their buffers can be released first).\n",
    "if 'outputs' in locals():\n",
    "  del outputs\n",
    "if 'memory' in locals():\n",
    "  del memory\n",
    "if 'seq' in locals():\n",
    "  del seq\n",
    "\n",
    "tick = time.time()\n",
    "print('starting generation', tick)\n",
    "\n",
    "outputs = sample_sequences(\n",
    "    events, rng, start_step, num_steps, temperature, top_p)\n",
    "\n",
    "# Wait for the computation to finish so that the timing below is meaningful.\n",
    "# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.\n",
    "outputs = jax.tree_util.tree_map(lambda x: x.block_until_ready(), outputs)\n",
    "tock = time.time()\n",
    "print('generation complete', tock)\n",
    "print(tock - tick, 'seconds')\n",
    "print((tock - tick) / 60, 'minutes')\n",
    "\n",
    "if use_memory:\n",
    "  events, rng, memory = outputs\n",
    "else:\n",
    "  events, rng = outputs\n",
    "\n",
    "# reshape to remove device axis\n",
    "events = events.reshape((np.prod(events.shape[:2]),) + events.shape[2:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wozd0BFib6q8",
    "outputId": "9314c6dd-cf74-4957-dbcd-d12bd1d8fcf1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Last inputs (reversed):\n",
      " [[153 159 239 ... 246 175 167]\n",
      " [ 20  73  23 ... 215  76 103]\n",
      " [ 21  15 176 ... 179 218  22]\n",
      " ...\n",
      " [ 31 227 189 ...  35 104  59]\n",
      " [223 235 122 ...  93 132 241]\n",
      " [244   6 186 ...  36  67 200]]\n",
      "First outputs:\n",
      " [[153 159 239 ... 246 175 167]\n",
      " [ 20  73  23 ... 215  76 103]\n",
      " [ 21  15 176 ... 179 218  22]\n",
      " ...\n",
      " [ 31 227 189 ...  35 104  59]\n",
      " [223 235 122 ...  93 132 241]\n",
      " [244   6 186 ...  36  67 200]]\n",
      "Number of matches [65535 65535 65535 65535 65535 65535 65535 65535]\n",
      "All match? True\n"
     ]
    }
   ],
   "source": [
    "#@title Do the inputs and outputs of the mirrored_input test match?\n",
    "if input_sequence_init == 'mirror_input':\n",
    "  # Resolve the \"fill the buffer\" default into a local name so that the\n",
    "  # num_steps form value is left untouched and this cell stays re-runnable.\n",
    "  if num_steps < 0:\n",
    "    steps_to_check = events.shape[1] - start_step + 1\n",
    "  else:\n",
    "    steps_to_check = num_steps\n",
    "\n",
    "  # Compare the same number of tokens on both sides of start_step. The length\n",
    "  # is bounded by the requested steps, by the prompt (position 0 is SOS) and\n",
    "  # by the positions before max_context_length - 1.\n",
    "  compare_len = max(\n",
    "      0,\n",
    "      min(steps_to_check,\n",
    "          start_step - 1,\n",
    "          max_context_length - 1 - start_step))\n",
    "  # Reversing the prompt tail lines it up with the mirrored continuation.\n",
    "  last_inputs = events[:, start_step - compare_len:start_step][:, ::-1]\n",
    "  first_outputs = events[:, start_step:start_step + compare_len]\n",
    "  print(f'Last inputs (reversed):\\n {last_inputs}')\n",
    "  print(f'First outputs:\\n {first_outputs}')\n",
    "  print('Number of matches', (first_outputs == last_inputs).sum(axis=-1))\n",
    "  print('All match?', np.all(first_outputs == last_inputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "b6bNeGfszNGE",
    "outputId": "2a61748f-7022-43dc-9013-573fd86ededc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[167 175 246 ... 246 175 167]\n",
      "#####\n",
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[103  76 215 ... 215  76 103]\n",
      "#####\n",
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[ 22 218 179 ... 179 218  22]\n",
      "#####\n",
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[ 19  29 154 ... 154  29  19]\n",
      "#####\n",
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[  4 249  54 ...  54 249   4]\n",
      "#####\n",
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[ 59 104  35 ...  35 104  59]\n",
      "#####\n",
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[241 132  93 ...  93 132 241]\n",
      "#####\n",
      "Found EOS at index 131071, truncating.\n",
      "Found SOS at index 0 as expected, removing.\n",
      "Shape: (131070,)\n",
      "[200  67  36 ...  36  67 200]\n",
      "#####\n"
     ]
    }
   ],
   "source": [
    "#@title Visualize or play the sampled sequence\n",
    "\n",
    "# Images are written to a fresh, timestamped directory next to the checkpoint\n",
    "# (colons are stripped from the timestamp to keep the directory name portable).\n",
    "inference_dir = checkpoint_base / 'inference' / datetime.datetime.now().isoformat().replace(\":\", \"\")\n",
    "\n",
    "for i, seq in enumerate(events):\n",
    "  # argmax alone cannot tell \"EOS at index 0\" from \"no EOS at all\" (both give\n",
    "  # 0), so test for presence explicitly.\n",
    "  eos_mask = seq == dataset.EOS_ID\n",
    "  if eos_mask.any():\n",
    "    eos_idx = int(jnp.argmax(eos_mask))\n",
    "    print(f'Found EOS at index {eos_idx}, truncating.')\n",
    "    seq = seq[:eos_idx]\n",
    "  else:\n",
    "    print('No EOS token found.')\n",
    "\n",
    "  # The length check keeps a fully truncated (empty) sequence from being\n",
    "  # indexed.\n",
    "  if seq.shape[0] > 0 and seq[0] == dataset.SOS_ID:\n",
    "    print('Found SOS at index 0 as expected, removing.')\n",
    "    seq = seq[1:]\n",
    "  else:\n",
    "    print('WARNING: SOS not found at index 0. This should not happen.')\n",
    "\n",
    "  if modality in ['image', 'image_w_positions']:\n",
    "    # Undo the reserved-token offset.\n",
    "    seq = seq - dataset.NUM_RESERVED_TOKENS\n",
    "    rem = len(seq) % 3\n",
    "    if rem > 0:\n",
    "      print(f'Truncating {rem} position(s) to ensure multiple of 3')\n",
    "      seq = seq[:-rem]\n",
    "    seq = seq.reshape(-1, 3)\n",
    "    if modality == 'image':\n",
    "      # Remove the per-channel offsets. The list is wrapped in jnp.array\n",
    "      # because jnp.broadcast_to expects an array, not a Python list.\n",
    "      seq -= jnp.broadcast_to(jnp.array([0, 256, 512]), seq.shape)\n",
    "    rem = len(seq) % 64\n",
    "    if rem > 0:\n",
    "      print(f'Truncating {rem} tuple(s) to ensure multiple of 64')\n",
    "      seq = seq[:-rem]\n",
    "    # Rows of 64 RGB pixels.\n",
    "    seq = seq.reshape(-1, 64, 3)\n",
    "\n",
    "    # Clamp negative values (left over after removing the offsets) to 0.\n",
    "    seq = jnp.where(seq >= 0, seq, 0)\n",
    "    print('-----')\n",
    "    print(seq.shape)\n",
    "    plt.imshow(seq)\n",
    "    plt.show()\n",
    "\n",
    "    inference_dir.mkdir(parents=True, exist_ok=True)\n",
    "    im_filename = f'im_{i}.png'\n",
    "    im_path = inference_dir / im_filename\n",
    "    img = Image.fromarray(np.array(seq, dtype=np.uint8), mode='RGB')\n",
    "    print('saving to', im_path)\n",
    "    IPython.display.display(img)\n",
    "    img.save(im_path)\n",
    "  elif modality == 'raw':\n",
    "    print('Shape:', seq.shape)\n",
    "    print(seq)\n",
    "    print('#####')\n",
    "  else:\n",
    "    raise ValueError(f'Unknown modality: {modality}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "name": "inference.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
